# Copyright (c) HySoP 2011-2024
#
# This file is part of HySoP software.
# See "https://particle_methods.gricad-pages.univ-grenoble-alpes.fr/hysop-doc/"
# for further info.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sympy as sm
from hysop.tools.numpywrappers import npw
from hysop.tools.htypes import first_not_None
from hysop.symbolic import Symbol, Dummy, subscript
from hysop.tools.sympy_utils import (
sstr,
sstrrepr,
latex as _latex,
UnevaluatedExpr,
UnsplittedExpr,
)
from contextlib import contextmanager
class ValueHolderI:
    """
    Interface for classes that may hold a value that
    can be replaced in sympy expressions.

    Subclasses override get_holded_value(); the classmethods
    get_holded_values() and replace_holded_values() walk an expression
    and substitute every holder that has a value.
    """

    def __new__(cls, *args, **kwds):
        # Cooperative __new__ so the interface can be mixed into sympy classes.
        return super().__new__(cls, *args, **kwds)

    def __init__(self, *args, **kwds):
        super().__init__(*args, **kwds)

    def get_holded_value(self):
        """Get holded value, defaults to None."""
        return None

    @classmethod
    def get_holded_values(cls, expr):
        """
        Collect the values held by the ValueHolderI objects found in expr.

        The traversal descends only through sympy Expr nodes (via .args);
        holders whose get_holded_value() is None are skipped.

        Returns
        -------
        dict
            Mapping holder -> holded value, in pre-order visiting order.
        """
        found = {}
        # Explicit stack; args are pushed reversed so nodes are visited in
        # the same pre-order as a recursive depth-first walk.
        stack = [expr]
        while stack:
            node = stack.pop()
            if isinstance(node, ValueHolderI):
                held = node.get_holded_value()
                if held is not None:
                    found[node] = held
            elif isinstance(node, sm.Expr):
                stack.extend(reversed(node.args))
        return found

    @classmethod
    def replace_holded_values(cls, expr):
        """
        Return expr with every holder replaced by its holded value.

        Objects without an xreplace method are returned unchanged.
        """
        mapping = cls.get_holded_values(expr)
        try:
            return expr.xreplace(mapping)
        except AttributeError:
            return expr
class ScalarDataViewHolder(ValueHolderI):
    """
    ValueHolderI that keeps a reference to some data plus an optional
    accessor used to extract the holded value from it.

    The accessor may be None (the data itself is the value), a callable
    (called with the data) or an index/key (used as data[access]).
    """

    def __new__(cls, holded_data_ref=None, holded_data_access=None, **kwds):
        if (
            isinstance(holded_data_ref, npw.ndarray)
            and (holded_data_access is None)
            and (holded_data_ref.size == 1)
        ):
            # Single-element array: default to reading its only element.
            # Use one zero per dimension so that 0-d arrays (index ()) and
            # arrays of shape (1, ..., 1) yield the scalar element instead
            # of raising IndexError or returning a sub-array.
            holded_data_access = (0,) * holded_data_ref.ndim
        obj = super().__new__(cls, **kwds)
        obj._holded_value_ref = holded_data_ref
        obj._holded_data_access = holded_data_access
        return obj

    def __init__(self, holded_data_ref=None, holded_data_access=None, **kwds):
        super().__init__(**kwds)

    def get_holded_value(self):
        """
        Return the holded value.

        None if no data is referenced, the data itself if there is no
        accessor, accessor(data) if the accessor is callable and
        data[accessor] otherwise.
        """
        if self._holded_value_ref is None:
            return None
        elif self._holded_data_access is None:
            return self._holded_value_ref
        elif callable(self._holded_data_access):
            return self._holded_data_access(self._holded_value_ref)
        else:
            return self._holded_value_ref[self._holded_data_access]

    def _hashable_content(self):
        """See sympy.core.basic.Basic._hashable_content()"""
        hc = super()._hashable_content()
        # The data is referenced by identity (id), not by value.
        hc += (
            id(self._holded_value_ref),
            self._holded_data_access,
        )
        return hc
class ScalarBaseTag:
    """
    Tag for objects that can be inserted as elements of tensors.

    Stores the element index ``idx`` (None by default) and appends it to
    the sympy hashable content, so that it takes part in hashing.
    """

    def __new__(cls, idx=None, **kwds):
        instance = super().__new__(cls, **kwds)
        instance._idx = idx
        return instance

    def __init__(self, idx=None, **kwds):
        super().__init__(**kwds)

    @property
    def idx(self):
        """Index of this scalar inside its tensor, or None."""
        return self._idx

    def _hashable_content(self):
        """See sympy.core.basic.Basic._hashable_content()"""
        content = super()._hashable_content()
        content += (self._idx,)
        return content
class ScalarBase(ScalarDataViewHolder, ScalarBaseTag):
    """
    Base for symbolic scalars.

    ``value`` and ``view`` are shorthands for the holded_data_ref and
    holded_data_access keyword arguments of ScalarDataViewHolder.
    """

    def __new__(cls, name, value=None, view=None, **kwds):
        # Map each shorthand onto its underlying keyword; giving both the
        # shorthand and the explicit keyword is a programming error.
        for shorthand, key in (
            (value, "holded_data_ref"),
            (view, "holded_data_access"),
        ):
            if shorthand is not None:
                assert kwds.get(key, None) is None
                kwds[key] = shorthand
        instance = super().__new__(cls, name=name, **kwds)
        instance._iterable = False
        return instance

    def __init__(self, name, value=None, view=None, **kwds):
        super().__init__(name=name, **kwds)

    def vreplace(self):
        """Call ValueHolderI.replace_holded_values on self."""
        return self.replace_holded_values(self)

    def __getitem__(self, key):
        # A scalar behaves like a one-element container: only index 0 is valid.
        assert key == 0
        return self
class TensorBase(npw.ndarray):
    """
    Base for symbolic tensors.

    A tensor is an npw.ndarray subclass (object dtype by default) containing
    symbolic scalars or symbolic expressions. It is meant to be read-only;
    write_context() temporarily grants write access.
    """

    # numpy's __array_priority__ influences which operand type is used for
    # the result of mixed operations.
    __array_priority__ = 1.0

    def __new__(
        cls,
        shape,
        init=None,
        name=None,
        pretty_name=None,
        scalar_cls=None,
        scalar_kwds=None,
        make_scalar_kwds=None,
        value=None,
        set_read_only=True,
        dtype=object,
        **kwds,
    ):
        """
        Create a new TensorBase.

        Parameters
        ----------
        shape: tuple of int
            Shape of the tensor.
        init: optional
            If given, assigned to the whole tensor (obj[...] = init) and
            no scalar is created.
        name: str
            Base name, required when init is None. Element names get the
            element index appended ("A" -> "A_0_1").
        pretty_name: str, optional
            Base pretty name (defaults to name); gets subscript indices.
        scalar_cls: type
            ScalarBaseTag subclass instantiated once per element when init
            is None.
        scalar_kwds: dict, optional
            Extra keyword arguments passed to every scalar.
        make_scalar_kwds: callable, optional
            Called with each element index; returns extra per-element
            keyword arguments, which must not overlap scalar_kwds.
        value: optional
            Forwarded as ``value=`` to every scalar.
        set_read_only: bool
            See NOTE below.
        dtype:
            Array dtype, object by default.
        **kwds:
            Forwarded to npw.ndarray.__new__.
        """
        set_read_only = first_not_None(set_read_only, True)
        # NOTE(review): set_read_only is normalized but never applied to
        # obj.flags.writeable, so tensors are currently writeable.
        obj = super().__new__(cls, shape=shape, dtype=dtype, **kwds)
        if init is None:
            assert name is not None
            pretty_name = first_not_None(pretty_name, name)
            assert scalar_cls is not None
            assert issubclass(scalar_cls, ScalarBaseTag)
            scalar_kwds = first_not_None(scalar_kwds, {})
            # LaTeX subscripts are concatenated (A_{01}) when every extent is
            # below 10 (single-digit indices), comma-separated otherwise.
            lsep = "" if npw.less(shape, 10).all() else ","
            vsep = "_"
            with obj.write_context():
                for idx in npw.ndindex(*shape):
                    # Per-element names are derived from the *base* name and
                    # stored in locals, so `name` is not rebound and suffixes
                    # do not accumulate from one element to the next.
                    suffix = vsep.join(str(i) for i in idx)
                    sname = "{}_{}".format(name, suffix)
                    pname = "{}{}".format(
                        pretty_name, "".join(subscript(i) for i in idx)
                    )
                    vname = "{}_{}".format(name, suffix)
                    lname = "{}_{{{}}}".format(name, lsep.join(str(i) for i in idx))
                    if make_scalar_kwds is None:
                        skwds = scalar_kwds
                    else:
                        assert callable(make_scalar_kwds)
                        idx_kwds = make_scalar_kwds(idx)
                        for k in idx_kwds.keys():
                            msg = f"{k} was already set in scalar_kwds."
                            assert k not in scalar_kwds, msg
                        # Merge into a fresh dict rather than mutating the
                        # mapping returned by make_scalar_kwds.
                        skwds = {**idx_kwds, **scalar_kwds}
                    obj[idx] = scalar_cls(
                        name=sname,
                        pretty_name=pname,
                        var_name=vname,
                        latex_name=lname,
                        value=value,
                        idx=idx,
                        **skwds,
                    )
        else:
            obj[...] = init
        return obj

    def __init__(
        self,
        shape,
        init=None,
        name=None,
        pretty_name=None,
        scalar_cls=None,
        scalar_kwds=None,
        make_scalar_kwds=None,
        value=None,
        set_read_only=True,
        dtype=object,
        **kwds,
    ):
        # Everything is done in __new__. The extra **kwds were already
        # consumed by ndarray.__new__; numpy's ndarray does not define its
        # own __init__, so forwarding them would reach object.__init__,
        # which rejects extra keyword arguments.
        super().__init__()

    def latex(self, matrix="b", with_packages=False):
        """
        Return a latex representation of this tensor.

        Parameters
        ----------
        matrix: str
            Prefix of the amsmath matrix environment ("b" gives bmatrix,
            "p" gives pmatrix, ...).
        with_packages: bool
            Prepend the amsmath usepackage line.

        Tensors of dimension 0, 1 (rendered as a column) and 2 are
        supported.
        """
        assert self.ndim <= 2
        if self.ndim == 0:
            rows = [_latex(self[()])]
        elif self.ndim == 1:
            rows = [_latex(val) for val in self]
        else:
            rows = [" & ".join(_latex(val) for val in row) for row in self]
        ss = ""
        if with_packages:
            ss += r"\usepackage{amsmath}"
        ss += "\n$$"
        ss += "\n" + rf"\begin{{{matrix}matrix}}"
        for row in rows:
            ss += "\n " + row + " \\\\"
        ss += "\n" + rf"\end{{{matrix}matrix}}"
        ss += "\n$$"
        return ss

    def sstr(self):
        """Elementwise sstr()."""
        return self.elementwise_fn(sstr)

    def strrepr(self):
        """Elementwise sstrrepr()."""
        return self.elementwise_fn(sstrrepr)

    def __str__(self):
        if self.ndim == 0:
            return sstr(self.tolist())
        if (self.ndim == 1) and (self.size > 1):
            # reshape as a column vector for display
            a = self.reshape(self.shape + (1,))
        else:
            a = self
        return npw.array2string(a, formatter={"all": str}, separator=" ")

    def __repr__(self):
        return npw.array2string(self, formatter={"all": sstrrepr})

    @contextmanager
    def write_context(self):
        """
        Temporarily grant write access to self for the duration of the context.

        The previous value of flags.writeable is restored on exit, even if
        the body raises. Only useful for tensors set as read-only.
        """
        old_flag = self.flags.writeable
        self.flags.writeable = True
        try:
            yield
        finally:
            self.flags.writeable = old_flag

    def elementwise_fn(self, fn):
        """
        Apply function fn on each element of the tensor and
        return the result as a Tensor.

        For 0-d tensors the bare result of fn is returned.
        """
        if self.ndim:
            data = npw.empty_like(self)
            for idx in npw.ndindex(*self.shape):
                data[idx] = fn(self[idx])
        else:
            data = fn(self.tolist())
        return data

    def __hash__(self):
        """Hash this object by its id."""
        return id(self)

    def diff(self, *symbols, **assumptions):
        """Elementwise sympy.diff()."""
        return self.elementwise_fn(lambda x: sm.diff(x, *symbols, **assumptions))

    def freeze(self):
        """Apply elementwise UnevaluatedExpr on each scalar expressions."""
        return self.elementwise_fn(UnevaluatedExpr)

    def no_split(self):
        """Apply elementwise UnsplittedExpr on each scalar expressions."""
        return self.elementwise_fn(UnsplittedExpr)

    def simplify(self):
        """Elementwise sympy.simplify()."""
        return self.elementwise_fn(lambda x: sm.simplify(x))

    def xreplace(self, replacements):
        """
        Elementwise sympy.xreplace().

        Array keys are expanded elementwise; an array value must then have
        the same shape as its key, while a non-array value is used for
        every element. Pairs where key or value is None are ignored.
        """
        replace = {}
        for k, v in replacements.items():
            if isinstance(k, npw.ndarray):
                for idx in npw.ndindex(*k.shape):
                    kk = k[idx]
                    if isinstance(v, npw.ndarray):
                        assert k.shape == v.shape
                        vv = v[idx]
                    else:
                        vv = v
                    if (kk is not None) and (vv is not None):
                        replace[kk] = vv
            elif (k is not None) and (v is not None):
                replace[k] = v
        data = npw.empty_like(self)
        for idx in npw.ndindex(*self.shape):
            data[idx] = self[idx].xreplace(replace)
        return data

    def vreplace(self):
        """Elementwise ValueHolderI.replace_holded_values on self."""
        data = npw.empty_like(self)
        for idx in npw.ndindex(*self.shape):
            data[idx] = ValueHolderI.replace_holded_values(self[idx])
        return data
class SymbolicScalar(ScalarBase, Symbol):
    """Symbolic scalar symbol: a value-holding, tensor-indexable Symbol."""
class DummySymbolicScalar(ScalarBase, Dummy):
    """Symbolic scalar dummy symbol: a value-holding, tensor-indexable Dummy."""
class SymbolicTensor(TensorBase):
    """Symbolic tensor symbol; elements default to SymbolicScalar."""

    def __new__(cls, name, shape, init=None, scalar_cls=None, **kwds):
        if scalar_cls is None:
            scalar_cls = SymbolicScalar
        return super().__new__(
            cls,
            name=name,
            shape=shape,
            init=init,
            scalar_cls=scalar_cls,
            **kwds,
        )

    def __init__(self, name, shape, init=None, scalar_cls=None, **kwds):
        super().__init__(
            name=name,
            shape=shape,
            init=init,
            scalar_cls=scalar_cls,
            **kwds,
        )
class DummySymbolicTensor(TensorBase):
    """Dummy symbolic tensor symbol; elements default to DummySymbolicScalar."""

    def __new__(cls, name, shape, init=None, scalar_cls=None, **kwds):
        if scalar_cls is None:
            scalar_cls = DummySymbolicScalar
        return super().__new__(
            cls,
            name=name,
            shape=shape,
            init=init,
            scalar_cls=scalar_cls,
            **kwds,
        )

    def __init__(self, name, shape, init=None, scalar_cls=None, **kwds):
        super().__init__(
            name=name,
            shape=shape,
            init=init,
            scalar_cls=scalar_cls,
            **kwds,
        )
def vreplace(expr):
    """
    Return expr with every ValueHolderI replaced by its holded value.

    Module-level shortcut for ValueHolderI.replace_holded_values; objects
    without an xreplace method are returned unchanged.
    """
    # The result must be returned: previously it was computed and dropped,
    # so this function always returned None.
    return ValueHolderI.replace_holded_values(expr)
if __name__ == "__main__":
    # --- scalars ------------------------------------------------------------
    a = SymbolicScalar("a", value=sm.Symbol("A"))
    # Dummy sharing the name of a: a different symbol with the same name.
    b = DummySymbolicScalar("a", value=sm.Symbol("B"))
    # Holds a list; view=1 selects its second entry (C1).
    c = DummySymbolicScalar("a", value=[sm.Symbol("C0"), sm.Symbol("C1")], view=1)
    # Same name as a; the demo expects it to be the same symbol as a.
    # NOTE(review): hashable content also includes id() of the holded value.
    d = SymbolicScalar("a", value=sm.Symbol("D"))
    total = a + b + c + d
    print(total)
    print(ValueHolderI.replace_holded_values(total))
    print()

    # --- tensors ------------------------------------------------------------
    A = SymbolicTensor("A", shape=(3, 3), value=12)
    B = SymbolicTensor("B", shape=(3, 3), set_read_only=False, value=npw.eye(3, 3))
    C = DummySymbolicTensor("C", shape=(8,))
    for tensor in (A, B, C):
        print(tensor)

    B[0, 0] = 0
    B[1, 0] = -1
    print()
    for tensor in (A, B):
        print(tensor)

    print()
    for tensor in (A, B):
        print(tensor.vreplace())

    print()
    print(A * B)
    print()
    cos_AB = A.dot(B).elementwise_fn(sm.cos)
    print(cos_AB)
    print()
    print(cos_AB.diff(B[1, 1]))
    print()
    print(A.latex())